#!/usr/bin/env python3

class ListNode:
    """A node of a singly linked list holding ``val`` and a ``next`` link."""

    def __init__(self, val):
        self.val = val
        self.next = None  # following node, or None at the tail

    def link(self, next):
        """Set ``next`` as this node's successor and return it.

        Returning the appended node allows chaining, e.g.
        ``a.link(b).link(c)`` builds ``a -> b -> c``.
        """
        self.next = next
        return next

    def print(self):
        """Print every value from this node to the tail, each followed by a space.

        Iterative rather than recursive so that lists longer than the
        interpreter's recursion limit can still be printed.
        """
        node = self
        while node is not None:
            # Inside a method body this name resolves to the builtin print,
            # not to this method.
            print(str(node.val), end=' ')
            node = node.next

class Solution:
    def remove_nth_from_end(self, head, n):
        """Remove the n-th node counted from the end of a singly linked list.

        Args:
            head: First node of the list (any object with a ``next``
                attribute), or None for an empty list.
            n: 1-based position from the end; ``n == 1`` removes the tail.

        Returns:
            The head of the resulting list. This is ``head.next`` when the
            removed node was the head itself, and None if the list becomes
            empty. If ``n`` is out of range (``n < 1`` or ``n`` greater than
            the list length), the list is left untouched and ``head`` is
            returned.

        The list is mutated in place. The pass is iterative (two pointers
        ``n`` nodes apart), so long lists cannot hit the recursion limit.
        """
        if head is None or n < 1:
            return head

        # Advance `lead` n nodes ahead of the head.
        lead = head
        for _ in range(n):
            if lead is None:
                return head  # n exceeds the list length: nothing to remove
            lead = lead.next

        if lead is None:
            # n equals the list length, so the node to remove is the head.
            return head.next

        # Move both pointers until `lead` is the tail. `trail` then sits
        # just before the node to remove.
        trail = head
        while lead.next is not None:
            lead = lead.next
            trail = trail.next
        trail.next = trail.next.next
        return head
    

if __name__ == '__main__':
    # Build 1 -> 2 -> 3 -> 4 -> 5; link() returns the appended node, so calls chain.
    linked = ListNode(1)
    (linked.link(ListNode(2))
           .link(ListNode(3))
           .link(ListNode(4))
           .link(ListNode(5)))

    solution = Solution()
    # remove_nth_from_end returns the (possibly new) head. Rebind it so that
    # removing the first node is reflected in what gets printed.
    linked = solution.remove_nth_from_end(linked, 2)
    if linked is not None:
        linked.print()
    print()
